# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

#################################################

# Simone Dussi
# Rycroft Group            
# SEAS, Harvard University 
# 
# code used in
# van der Hoeven et al., "Quantitative 3D Characterization of Catalytic Nanoparticle-Support Interfaces"

# Analysis of a region of interest (ROI) previously identified and stored
# SEGMENTATION   
# -> goal is to assign each voxel to particle(s) (P), support (S), or void/pore (V)
# - image is denoised
# - segmentation of P vs S+V, using thresholding algorithm
# - segmentation of P+S vs V, using watershed algorithm (with inputs from multi-phase thresholding)
# - combine segmentations to obtain P / S / V
# - clean segmentation by removing spurious support layer around particle (remove boundary + opening)
# QUANTIFICATION
# -> from cleaned segmentation, calculate particle size and exposed particle surface 
# - identify particle of interest (POI), so analysis is performed only on one particle per ROI
# - based on boundaries between P-S and P-V, count voxels of exposed and embedded particle surface


import os
import numpy as np
import matplotlib.pyplot as plt

from scipy import ndimage as ndi
from skimage.segmentation import watershed, find_boundaries
from skimage.morphology import ball, remove_small_objects, opening
from skimage.filters import threshold_yen, threshold_multiotsu, sobel  
from skimage.restoration import denoise_tv_chambolle
from skimage.measure import label, find_contours
from skimage.color import gray2rgb

### Some auxiliary functions for plotting/saving

# for plotting single panel
def plt_panel(ax, i, img_to_plot, title_label):
    """Draw ``img_to_plot`` in grayscale on panel ``ax[i]``.

    The panel's axes are hidden and ``title_label`` is used as a small
    (fontsize 10) title.
    """
    panel = ax[i]
    panel.imshow(img_to_plot, cmap="gray")
    panel.set_axis_off()
    panel.set_title(title_label, fontsize=10)

# for plotting contours
def plt_cntr(ax, i, contours, clr='y', lw=0.5):
    """Overlay contour lines on panel ``ax[i]``.

    Each contour is an (N, 2) array of (row, col) points, as produced by
    ``find_contours`` elsewhere in this file. Columns go on the x axis and rows
    on the y axis, so the lines align with an ``imshow`` of the same slice.
    """
    for contour in contours:
        rows = contour[:, 0]
        cols = contour[:, 1]
        ax[i].plot(cols, rows, linewidth=lw, color=clr)

# for plotting slices of a list of 3D images
def general_plot_slices (list_images,label_images,Nslices=30,orientation=2,
                         show_plots=True,save_plots=False,filename=''):
    """Plot ``Nslices`` central slices of every 3D image in ``list_images``.

    Each figure row holds ``pltcols`` consecutive slices. For every slice, all
    images are drawn side by side, titled with their ``label_images`` entry
    and the slice coordinate.

    Parameters
    ----------
    list_images : sequence of 3D arrays
        Images to plot. All are sliced with the same indices, so they are
        expected to share the shape of ``list_images[0]``.
    label_images : sequence
        Title prefix for each image (same length as ``list_images``).
    Nslices : int
        Number of slices around the volume centre to plot. The range is
        clipped to the volume size.
    orientation : {0, 1, 2}
        0 -> slices along the last array axis (titled "x="),
        1 -> along the middle axis ("y="),
        2 -> along the first axis ("z=").
    show_plots, save_plots : bool
        Show the figure interactively and/or save it to ``filename``.
    filename : str
        Output path used when ``save_plots`` is true.

    Raises
    ------
    ValueError
        If ``list_images`` is empty or ``orientation`` is not 0, 1 or 2.
    """
    if orientation not in (0,1,2):
        raise ValueError("orientation must be 0, 1 or 2, got {!r}".format(orientation))
    Nimages=len(list_images)
    if Nimages==0:
        raise ValueError("list_images is empty")
    # Array axis that is sliced, and the coordinate name shown in titles.
    # The slice range is computed from the size of that same axis.
    slice_axis={0:2,1:1,2:0}[orientation]
    coord_name={0:"x",1:"y",2:"z"}[orientation]
    im_size=np.shape(list_images[0])[slice_axis]
    panelsize=3
    # Consecutive slices per figure row: fewer when more images share a row.
    # (Nimages == 2 previously left pltcols undefined.)
    if Nimages==1:
        pltcols=5
    elif Nimages<=3:
        pltcols=2
    else:
        pltcols=1
    slicemin=max(int(0.5*im_size-Nslices/2),0)
    slicemax=min(int(0.5*im_size+Nslices/2),im_size)
    slices=range(slicemin,slicemax)
    # Ceiling division so a partial last row still gets panels; at least one row.
    figrows=max((len(slices)+pltcols-1)//pltcols,1)
    figcols=Nimages*pltcols
    # squeeze=False always returns a 2D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(figrows, figcols,
                             figsize=(figcols*panelsize, figrows*panelsize),
                             squeeze=False)
    ax = axes.flatten()
    iax=0
    for svalue in slices:
        for k in range(Nimages):
            toplot=np.take(list_images[k],svalue,axis=slice_axis)
            tolab=str(label_images[k])+" "+coord_name+"="+str(svalue)
            plt_panel(ax, iax, toplot, tolab)
            iax+=1
    # Hide the panels left empty in a partial last row.
    for unused in ax[iax:]:
        unused.set_axis_off()

    fig.tight_layout()
    # Save before showing: saving after a blocking plt.show() can write an
    # empty figure.
    if save_plots:
        fig.savefig(filename,dpi=200)
    if show_plots:
        plt.show()
    plt.close(fig)


# some plots to show procedure steps
# list_images = ["orig", "supp_seg", "clean_seg", "exp/emb"]
def procedure_plots(list_images,Nslices=20,show_plots=True,
                    save_plots=False,filename='./procedure_plots.png'):
    """Plot the main segmentation/quantification steps for central z-slices.

    ``list_images`` holds four 3D volumes, all sliced along their first axis.
    As called from ``analyze_ROI`` they are:

    - [0] the original image;
    - [1] the support segmentation (contoured at 0.5);
    - [2] the cleaned segmentation (contoured at 0.5 and 1.5);
    - [3] the exposed/embedded map (contoured at 19 and 29).

    One figure row of six panels is drawn per slice: original, particle seg,
    support seg, cleaned seg (inverted gray levels), support/particle
    contours, and exposed/embedded contours.

    Parameters
    ----------
    Nslices : int
        Number of central slices (one figure row each), clipped to the volume.
    show_plots, save_plots : bool
        Show the figure and/or save it to ``filename``.
    """
    orig=list_images[0]
    supp_seg=list_images[1]
    clean_seg=list_images[2]
    exp_emb=list_images[3]
    # Slices are taken along axis 0, so the range comes from that axis.
    im_size=np.shape(orig)[0]
    panelsize=3
    figcols=6
    slicemin=max(int(0.5*im_size-Nslices/2),0)
    slicemax=min(int(0.5*im_size+Nslices/2),im_size)
    figrows=max(slicemax-slicemin,1)
    fig, axes = plt.subplots(figrows, figcols,
                             figsize=(figcols*panelsize, figrows*panelsize),
                             squeeze=False)
    ax = axes.flatten()
    iax=0
    for svalue in range(slicemin,slicemax):
        orig_slice=orig[svalue,:,:]
        clean_slice=clean_seg[svalue,:,:]

        # original slice
        plt_panel(ax, iax, orig_slice, "original z="+str(svalue))
        iax+=1

        # particle segmentation: contour of the cleaned segmentation at 1.5
        plt_panel(ax, iax, orig_slice, "particle seg")
        plt_cntr(ax, iax, find_contours(clean_slice, 1.5), 'y', lw=1.5)
        iax+=1

        # support segmentation (drawn once; previously imshow was called twice)
        plt_panel(ax, iax, orig_slice, "support seg")
        plt_cntr(ax, iax, find_contours(supp_seg[svalue,:,:], 0.5), 'r', lw=1.0)
        iax+=1

        # cleaned segmentation with inverted gray levels
        # (negate only the slice, not the whole volume)
        plt_panel(ax, iax, -clean_slice, "cleaned seg")
        iax+=1

        # support (0.5) and particle (1.5) contours of the cleaned segmentation
        plt_panel(ax, iax, orig_slice, "supp/part")
        plt_cntr(ax, iax, find_contours(clean_slice, 0.5), 'r', lw=1.0)
        plt_cntr(ax, iax, find_contours(clean_slice, 1.5), 'y', lw=1.5)
        iax+=1

        # exposed/embedded map contoured at 19 and 29
        exp_emb_slice=exp_emb[svalue,:,:]
        plt_panel(ax, iax, orig_slice, "exp/emb")
        plt_cntr(ax, iax, find_contours(exp_emb_slice, 19.0), 'g', lw=1.5)
        plt_cntr(ax, iax, find_contours(exp_emb_slice, 29.0), 'y', lw=1.5)
        iax+=1

    fig.tight_layout()
    if save_plots:
        fig.savefig(filename,dpi=200)
    if show_plots:
        plt.show()
    plt.close(fig)
    
# list images: "original", "cleaned_segmentation", "exp/emb highlighted"
# plotted is the original image with 4 types of contours (support, all particles, POI surface, embedded POI surface)
def save_summary_slices_v1(list_images,properties,Nslices=20,show_plots=False):
    """Save (and optionally show) a grid of central z-slices with contours.

    ``list_images`` is [original, cleaned segmentation, exposed/embedded map],
    all sliced along their first axis. Each panel shows the original slice
    overlaid with contours of:

    - the support (red, cleaned seg at 0.5);
    - all particles ('C1', cleaned seg at 1.5);
    - the POI surface (green, exp/emb map at 19);
    - the embedded POI surface (yellow, exp/emb map at 29).

    ``properties`` is [ROI id, diameter in nm, % exposed surface]. It is used
    in the figure title and in the output name ``./summary1_ROI_<id>.png``.

    Parameters
    ----------
    Nslices : int
        Number of central slices, laid out 5 per row. The range is clipped to
        the volume, and unused panels in the last row are hidden.
    show_plots : bool
        Also show the figure interactively (it is always saved).
    """
    # Slices are taken along axis 0, so the range comes from that axis.
    im_size=np.shape(list_images[0])[0]
    panelsize=4
    pltcols=5
    slicemin=max(int(0.5*im_size-Nslices/2),0)
    slicemax=min(int(0.5*im_size+Nslices/2),im_size)
    slices=range(slicemin,slicemax)
    # Ceiling division: a partial last row must still exist (truncating
    # previously raised IndexError when Nslices was not a multiple of pltcols).
    figrows=max((len(slices)+pltcols-1)//pltcols,1)
    figcols=pltcols
    fig, axes = plt.subplots(figrows, figcols,
                             figsize=(figcols*panelsize, figrows*panelsize),
                             squeeze=False)
    ax = axes.flatten()
    for iax, svalue in enumerate(slices):
        seg_slice=list_images[1][svalue,:,:]
        poi_slice=list_images[2][svalue,:,:]
        plt_panel(ax, iax, list_images[0][svalue,:,:], "z="+str(svalue))
        plt_cntr(ax, iax, find_contours(seg_slice, 0.5), 'r', lw=0.5)    # support
        plt_cntr(ax, iax, find_contours(seg_slice, 1.5), 'C1', lw=1.5)   # all particles
        plt_cntr(ax, iax, find_contours(poi_slice, 19.0), 'g', lw=1.5)   # POI surface
        plt_cntr(ax, iax, find_contours(poi_slice, 29.0), 'y', lw=2.0)   # POI embedded
    for unused in ax[len(slices):]:
        unused.set_axis_off()

    fig.tight_layout()
    fig.subplots_adjust(top=0.95)
    plottitle="ROI={:d} diam={:.2f} nm %S_exp={:.1f}".format(int(properties[0]),properties[1],properties[2])
    fig.suptitle(plottitle)
    fig.savefig("./summary1_ROI_{:d}.png".format(int(properties[0])),dpi=200)
    if show_plots:
        plt.show()
    plt.close(fig)
    

# list images: "original", "cleaned_segmentation", "exp/emb highlighted"
# plotted is the original image with POI surface voxels colored green (exposed) or yellow (embedded)
# optionally, we can also plot the support contour
def save_summary_slices_v2(list_images,properties,Nslices=20,show_plots=False):
    """Save (and optionally show) a grid of central z-slices with the POI surface colored.

    ``list_images`` is [original, cleaned segmentation, exposed/embedded map],
    sliced along their first axis. The cleaned segmentation is only used by
    the commented-out support contour below.

    Each panel shows the original slice as RGB. Voxels coded 20 in the
    exposed/embedded map are painted green (exposed) and voxels coded 30 are
    painted yellow (embedded).

    ``properties`` is [ROI id, diameter in nm, % exposed surface]. It is used
    in the figure title and in the output name ``./summary2_ROI_<id>.png``.

    Parameters
    ----------
    Nslices : int
        Number of central slices, laid out 5 per row. The range is clipped to
        the volume, and unused panels in the last row are hidden.
    show_plots : bool
        Also show the figure interactively (it is always saved).
    """
    # Slices are taken along axis 0, so the range comes from that axis.
    im_size=np.shape(list_images[0])[0]
    panelsize=4
    pltcols=5
    slicemin=max(int(0.5*im_size-Nslices/2),0)
    slicemax=min(int(0.5*im_size+Nslices/2),im_size)
    slices=range(slicemin,slicemax)
    # Ceiling division: a partial last row must still exist (truncating
    # previously raised IndexError when Nslices was not a multiple of pltcols).
    figrows=max((len(slices)+pltcols-1)//pltcols,1)
    figcols=pltcols
    fig, axes = plt.subplots(figrows, figcols,
                             figsize=(figcols*panelsize, figrows*panelsize),
                             squeeze=False)
    ax = axes.flatten()
    for iax, svalue in enumerate(slices):
        ## for support contour:
        #cont = find_contours(np.copy(list_images[1][svalue,:,:]), 0.5)
        #plt_cntr (ax, iax, cont, 'r', lw=1.0)
        # Copy defensively, since colored voxels are written into toplot.
        toplot=gray2rgb(np.copy(list_images[0][svalue,:,:]))
        poi_ee=list_images[2][svalue,:,:]
        toplot[poi_ee==20]=(0,1.0,0)    #exp
        toplot[poi_ee==30]=(1.0,1.0,0)  #emb
        plt_panel(ax, iax, toplot, "z="+str(svalue))
    for unused in ax[len(slices):]:
        unused.set_axis_off()

    fig.tight_layout()
    fig.subplots_adjust(top=0.95)
    plottitle="ROI={:d} diam={:.2f} nm %S_exp={:.1f}".format(int(properties[0]),properties[1],properties[2])
    fig.suptitle(plottitle)
    fig.savefig("./summary2_ROI_{:d}.png".format(int(properties[0])),dpi=200)
    if show_plots:
        plt.show()
    plt.close(fig)



### Load data for region of interest (ROI)

def load_data (idROI, shape=(150,150,150), data_dir='.'):
    """Load a stored ROI volume and min-max rescale it to [0, 1].

    Parameters
    ----------
    idROI : int or str
        ROI identifier. The file read is ``<data_dir>/ROIdata_<idROI>``.
    shape : tuple of int
        Shape the flat text data is reshaped to. The default is the
        previously hard-coded 150x150x150.
    data_dir : str
        Directory containing the ROI files (current directory by default).

    Returns
    -------
    numpy.ndarray
        Volume of the given shape with intensities rescaled to [0, 1]. A
        constant volume is returned as all zeros instead of dividing by zero.

    Raises
    ------
    OSError
        If the file cannot be read (from ``np.loadtxt``).
    ValueError
        If the file does not hold exactly prod(shape) values (from reshape).
    """
    filename_ROIdata=os.path.join(data_dir,"ROIdata_"+str(idROI))
    raw_img=np.loadtxt(filename_ROIdata).reshape(shape)
    # rescale intensity
    min_img=np.min(raw_img)
    value_range=np.max(raw_img)-min_img
    if value_range==0:
        # A constant image has no contrast to rescale; avoid 0/0 -> NaN.
        return np.zeros_like(raw_img)
    return (raw_img-min_img)/value_range

### Segmentation + Quantification

def _poi_label(labels):
    """Return the label of the particle of interest (POI) in a labelled volume.

    The POI is the particle containing the central voxel of the volume. If
    that voxel is background (label 0), e.g. for an ill-centred ROI, the
    particle owning the labelled voxel closest to the centre is used instead.
    Without this fallback, label 0 would select all non-particle voxels.

    Raises
    ------
    ValueError
        If ``labels`` contains no labelled (particle) voxel at all.
    """
    center=tuple(int(0.5*n) for n in labels.shape)
    labPOI=labels[center]
    if labPOI!=0:
        return labPOI
    coords=np.argwhere(labels>0)
    if coords.size==0:
        raise ValueError("no particle found in ROI")
    print("warning: ROI center is not inside a particle; using the nearest particle as POI")
    dist2=np.sum((coords-np.asarray(center))**2,axis=1)
    return labels[tuple(coords[np.argmin(dist2)])]


def analyze_ROI (ROIid, voxel_resolution, verbose=0, show_procedure_plots=False, show_summary_plots=False):
    """Segment one ROI and quantify its particle of interest (POI).

    Pipeline:

    - load and denoise the ROI volume;
    - segment particle (P) vs support+void with Yen thresholding;
    - segment P+S vs void (V) with a marker-based watershed;
    - combine and clean the two segmentations;
    - select the POI (the particle at the ROI centre);
    - count its exposed and embedded surface voxels.

    Parameters
    ----------
    ROIid : int
        ROI identifier; the volume is read with ``load_data(ROIid)``.
    voxel_resolution : float
        Voxel edge length in nm. Converts the effective diameter from voxels
        to nm.
    verbose : int
        Print progress messages when > 0.
    show_procedure_plots : bool
        Show and save (``./procedure_plots.png``) a figure of the intermediate
        steps.
    show_summary_plots : bool
        Show the summary figure. It is always saved as
        ``./summary2_ROI_<ROIid>.png``.

    Returns
    -------
    numpy.ndarray
        ``[ROIid, effective diameter in nm, % of exposed POI surface]``. The
        percentage is NaN if the POI has no surface voxels.

    Raises
    ------
    ValueError
        If the particle segmentation contains no particle.
    """
    ### Load data
    raw_img=load_data(ROIid)

    if verbose>0:
        print("loaded ROI=",ROIid)

    ### Apply filter
    filt_img=denoise_tv_chambolle(raw_img,weight=0.1)

    ### Particle segmentation (P vs S+V)
    thresholdYE = threshold_yen(filt_img)
    segmented = filt_img >= thresholdYE
    # Some cleaning (typically not necessary): fill holes in the below-threshold
    # phase and drop objects smaller than 10 voxels.
    filled_segmented=ndi.binary_fill_holes(1-segmented)
    # NOTE: P_seg is 0 on particle voxels and 1 elsewhere; the code below uses
    # (1 - P_seg) as the particle mask.
    P_seg=1.0-remove_small_objects(filled_segmented,10)

    ### Segmentation P+S vs V
    # multi-Otsu thresholding (used to define the watershed markers)
    thresholdsMO = threshold_multiotsu(filt_img)
    # Marker band: +/- 2.5% of the intensity range around the upper multi-Otsu
    # threshold (simpler choices usually work too).
    delta_thr=0.025*(np.max(filt_img)-np.min(filt_img))
    thr_lo=thresholdsMO[1]-delta_thr
    thr_hi=thresholdsMO[1]+delta_thr
    # Marker-based watershed. Markers are integer labels; 0 = unlabelled.
    markers=np.zeros(filt_img.shape,dtype=int)
    markers[filt_img < thr_lo]=1
    markers[filt_img > thr_hi]=2
    elevation_map=sobel(filt_img)
    segmented=watershed(elevation_map,markers)
    # Some cleaning: fill holes in basin 2, then remove thin parts by opening.
    filled_segmented=ndi.binary_fill_holes(1-segmented)
    S_seg=1.0-opening(filled_segmented,ball(3))

    ### Combine the two partial segmentations -> V (0), S (1), P (2)
    PS_seg = np.copy(S_seg)
    PS_seg [ P_seg==0 ] = 2.0

    if verbose>0:
        print("segmented, now cleaning and counting...")

    ### Clean combined segmentation by removing the spurious (wrongly
    ### segmented) thin support layer around the particle
    # find support boundaries
    bnd_S=find_boundaries(S_seg,connectivity=1,mode='thick')
    # keep only support voxels that are not on the support boundary
    removed_bndS=np.copy(PS_seg)
    removed_bndS[bnd_S]=0
    removed_bndS[removed_bndS==2]=0
    # Opening turns spurious (thin) support fragments into void.
    cleaned_support=opening(removed_bndS,ball(1))
    # re-combine with particle segmentation: 0 = void, 1 = support, 2 = particle
    cleaned_seg=2*(1-P_seg)
    cleaned_seg [ cleaned_support==1 ]=1

    ### Identify particle of interest (POI)
    # label all particles
    labels = label(1-P_seg)
    # The POI is the particle at the ROI centre, with a nearest-particle
    # fallback when the centre lies outside every particle.
    labPOI = _poi_label(labels)

    # Note: it is also possible to extract the meshes associated with the
    # particles using marching cubes (from scikit-image) and e.g. the trimesh
    # package. Such meshes can be used both to identify the POI (the mesh
    # centroid closest to the ROI centre) and subsequently to calculate
    # various particle quantities. For simplicity, that analysis is not
    # included here.

    # keep only the POI in the particle segmentation (1 on POI, 0 elsewhere)
    onlyPOI=(labels==labPOI).astype(P_seg.dtype)

    ### Calculate embedding by voxel counting
    # find boundary of clean support
    bnd_clean_S=find_boundaries(cleaned_support,connectivity=1,mode='thick')
    # find boundary of POI
    bnd_P=find_boundaries(onlyPOI,connectivity=1,mode='thick')
    # voxels on both boundaries = embedded POI surface
    both_bnds=bnd_clean_S*bnd_P
    # Highlight the POI surface. Codes (later assignments win):
    # 3 = POI, 10 = support boundary, 20 = exposed POI surface,
    # 30 = embedded POI surface.
    seg_exp_emb=np.copy(cleaned_seg)
    seg_exp_emb[onlyPOI == 1]=3
    seg_exp_emb[bnd_clean_S]=10
    seg_exp_emb[bnd_P]=20
    seg_exp_emb[both_bnds]=30
    # Count voxels (vectorized; replaces a Python loop over every voxel).
    # NOTE(review): Pvol counts only voxels still coded 3, i.e. POI voxels not
    # on the thick POI/support boundaries. The inner boundary shell is
    # therefore excluded from the volume; confirm this is the intended
    # convention.
    Pvol=int(np.count_nonzero(seg_exp_emb==3))
    Psurf_exp=int(np.count_nonzero(seg_exp_emb==20))
    Psurf_emb=int(np.count_nonzero(seg_exp_emb==30))
    Psurf_tot=Psurf_exp+Psurf_emb

    # calculate properties
    eff_diam_vox=pow(6.0/np.pi*Pvol,1./3.) #assume spherical shape
    eff_diam_nm=eff_diam_vox*voxel_resolution
    if Psurf_tot>0:
        perc_exp_surf=100.0*Psurf_exp/Psurf_tot
    else:
        # No surface voxels: the percentage is undefined (avoid ZeroDivisionError).
        perc_exp_surf=float('nan')

    if verbose>0:
        print("saving...")

    ### Plot procedure steps, if requested
    if show_procedure_plots:
        #ls_im=[raw_img,P_seg,S_seg,PS_seg,cleaned_seg,seg_exp_emb]
        #ls_lb=['original','Pseg','Sseg','init_seg','cleaned_seg','exp/emb']
        #general_plot_slices(ls_im,ls_lb,Nslices=1,orientation=2)

        # procedure plots: list_images = "orig", "supp_seg", "clean_seg", "exp/emb"
        ls_im=[raw_img,S_seg,cleaned_seg,seg_exp_emb]
        procedure_plots(ls_im,Nslices=1,show_plots=True,save_plots=True)

    ### Save properties and list of images for visual inspection
    ls_im=[raw_img,cleaned_seg,seg_exp_emb]
    properties=np.asarray([ROIid,eff_diam_nm,perc_exp_surf])
    #save_summary_slices_v1(ls_im,properties,Nslices=25,show_plots=show_summary_plots)
    save_summary_slices_v2(ls_im,properties,Nslices=25,show_plots=show_summary_plots)

    if verbose>0:
        print("done...")

    return properties



### Run example: analyze a single ROI
# NOTE(review): this runs whenever the module is imported; consider moving it
# under an `if __name__ == "__main__":` guard.
voxel_resolution = 0.77085156  # voxel edge length, in nm
idROI = 103
prop = analyze_ROI(
    idROI,
    voxel_resolution,
    verbose=1,
    show_procedure_plots=False,
    show_summary_plots=False,
)
